{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: gnn in /usr/local/lib/python3.7/site-packages (1.1.5)\r\n"
     ]
    }
   ],
   "source": [
    "!pip3 install gnn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import gnn.gnn_utils as gnn_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reading train, validation dataset\n",
    "data_path = \"./data\"\n",
    "set_name = \"sub_15_7_200\"\n",
    "############# training set ################\n",
    "inp, arcnode, nodegraph, nodein, labels, _ = gnn_utils.set_load_general(data_path, \"train\", set_name=set_name)\n",
    "inp = [a[:, 1:] for a in inp]\n",
    "############ validation set #############\n",
    "inp_val, arcnode_val, nodegraph_val, nodein_val, labels_val, _ = gnn_utils.set_load_general(data_path, \"validation\", set_name=set_name)\n",
    "inp_val = [a[:, 1:] for a in inp_val]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_dim = len(inp[0][0])\n",
    "state_dim = 10\n",
    "output_dim = 2\n",
    "state_threshold = 0.001\n",
    "max_iter = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f_w(inp):\n",
    "    with tf.variable_scope('State_net'):\n",
    "        layer1 = tf.layers.dense(inp, 5, activation=tf.nn.sigmoid)\n",
    "        layer2 = tf.layers.dense(layer1, state_dim, activation=tf.nn.sigmoid)\n",
    "        return layer2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def g_w(inp):\n",
    "    with tf.variable_scope('Output_net'):\n",
    "        layer1 = tf.layers.dense(inp, 5, activation=tf.nn.sigmoid)\n",
    "        layer2 = tf.layers.dense(layer1, output_dim, activation=None)\n",
    "        return layer2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "#init input placeholder\n",
    "tf.reset_default_graph()\n",
    "comp_inp = tf.placeholder(tf.float32, shape=(None, input_dim), name=\"input\")\n",
    "y = tf.placeholder(tf.float32, shape=(None, output_dim), name=\"target\")\n",
    "\n",
    "# state(t) & state(t-1)\n",
    "state = tf.placeholder(tf.float32, shape=(None, state_dim), name=\"state\")\n",
    "state_old = tf.placeholder(tf.float32, shape=(None, state_dim), name=\"old_state\")\n",
    "\n",
    "# arch-node conversion matrix\n",
    "ArcNode = tf.sparse_placeholder(tf.float32, name=\"ArcNode\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convergence(a, state, old_state, k):\n",
    "    with tf.variable_scope('Convergence'):\n",
    "        # assign current state to old state\n",
    "        old_state = state\n",
    "        \n",
    "        # 获取子结点上一个时刻的状态\n",
    "        # grub states of neighboring node \n",
    "        gat = tf.gather(old_state, tf.cast(a[:, 0], tf.int32))\n",
    "        \n",
    "        # 去除第一列，即子结点的id\n",
    "        # slice to consider only label of the node and that of it's neighbor \n",
    "        # sl = tf.slice(a, [0, 1], [tf.shape(a)[0], tf.shape(a)[1] - 1])\n",
    "        # equivalent code\n",
    "        sl = a[:, 1:]\n",
    "        \n",
    "        # 将子结点上一个时刻的状态放到最后一列\n",
    "        # concat with retrieved state\n",
    "        inp = tf.concat([sl, gat], axis=1)\n",
    "\n",
    "        # evaluate next state and multiply by the arch-node conversion matrix to obtain per-node states\n",
    "        #计算子结点对父结点状态的贡献\n",
    "        layer1 = f_w(inp)\n",
    "        #聚合子结点对父结点状态的贡献，得到当前时刻的父结点的状态\n",
    "        state = tf.sparse_tensor_dense_matmul(ArcNode, layer1)\n",
    "\n",
    "        # update the iteration counter\n",
    "        k = k + 1\n",
    "    return a, state, old_state, k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def condition(a, state, old_state, k):\n",
    "    # evaluate condition on the convergence of the state\n",
    "    with tf.variable_scope('condition'):\n",
    "        # 检查当前状态和上一个时刻的状态的欧式距离是否小于阈值\n",
    "        # evaluate distance by state(t) and state(t-1)\n",
    "        outDistance = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(state, old_state)), 1) + 1e-10)\n",
    "        # vector showing item converged or not (given a certain threshold)\n",
    "        checkDistanceVec = tf.greater(outDistance, state_threshold)\n",
    "        \n",
    "        c1 = tf.reduce_any(checkDistanceVec)\n",
    "        \n",
    "        # 是否达到最大迭代次数\n",
    "        c2 = tf.less(k, max_iter)\n",
    "\n",
    "    return tf.logical_and(c1, c2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:423: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n",
      "WARNING:tensorflow:From <ipython-input-5-23a797daad94>:3: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dense instead.\n"
     ]
    }
   ],
   "source": [
    "# 迭代计算，直到状态达到稳定状态\n",
    "# compute state\n",
    "with tf.variable_scope('Loop'):\n",
    "    k = tf.constant(0)\n",
    "    res, st, old_st, num = tf.while_loop(condition, convergence,\n",
    "                                         [comp_inp, state, state_old, k])\n",
    "    # 计算结点的output\n",
    "    out = g_w(st)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/tensorflow/python/ops/array_grad.py:425: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.cast instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.7/site-packages/tensorflow/python/ops/gradients_impl.py:110: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n",
      "  \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n"
     ]
    }
   ],
   "source": [
    "loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=out))\n",
    "accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(out, 1), tf.argmax(y, 1)), dtype=tf.float32))\n",
    "optimizer = tf.train.AdamOptimizer(0.001)\n",
    "grads = optimizer.compute_gradients(loss)\n",
    "train_op = optimizer.apply_gradients(grads, name='train_op')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter 0\t training loss: 0.69009537,\t training accuracy: 0.526,\t validation loss: 0.68896145,\t validation accuracy: 0.5416667\n",
      "iter 100\t training loss: 0.6728105,\t training accuracy: 0.5743333,\t validation loss: 0.6688947,\t validation accuracy: 0.58166665\n",
      "iter 200\t training loss: 0.65165895,\t training accuracy: 0.6206667,\t validation loss: 0.6487902,\t validation accuracy: 0.6196667\n",
      "iter 300\t training loss: 0.6241475,\t training accuracy: 0.63133335,\t validation loss: 0.6229102,\t validation accuracy: 0.63133335\n",
      "iter 400\t training loss: 0.5859205,\t training accuracy: 0.67366666,\t validation loss: 0.58706313,\t validation accuracy: 0.6626667\n",
      "iter 500\t training loss: 0.54777163,\t training accuracy: 0.7073333,\t validation loss: 0.54999274,\t validation accuracy: 0.70666665\n",
      "iter 600\t training loss: 0.5188507,\t training accuracy: 0.7403333,\t validation loss: 0.5230317,\t validation accuracy: 0.74233335\n",
      "iter 700\t training loss: 0.49069414,\t training accuracy: 0.7643333,\t validation loss: 0.49758184,\t validation accuracy: 0.76166666\n",
      "iter 800\t training loss: 0.47247437,\t training accuracy: 0.7776667,\t validation loss: 0.4826656,\t validation accuracy: 0.76933336\n",
      "iter 900\t training loss: 0.4598218,\t training accuracy: 0.78566664,\t validation loss: 0.47266373,\t validation accuracy: 0.77666664\n",
      "iter 1000\t training loss: 0.44883335,\t training accuracy: 0.793,\t validation loss: 0.46385023,\t validation accuracy: 0.7816667\n",
      "iter 1100\t training loss: 0.4374816,\t training accuracy: 0.79966664,\t validation loss: 0.45490763,\t validation accuracy: 0.789\n",
      "iter 1200\t training loss: 0.42728108,\t training accuracy: 0.80833334,\t validation loss: 0.44681072,\t validation accuracy: 0.79333335\n",
      "iter 1300\t training loss: 0.41820458,\t training accuracy: 0.81633335,\t validation loss: 0.44001114,\t validation accuracy: 0.79866666\n",
      "iter 1400\t training loss: 0.40614474,\t training accuracy: 0.8236667,\t validation loss: 0.4295776,\t validation accuracy: 0.80466664\n",
      "iter 1500\t training loss: 0.39452788,\t training accuracy: 0.82666665,\t validation loss: 0.41836622,\t validation accuracy: 0.8053333\n",
      "iter 1600\t training loss: 0.38394842,\t training accuracy: 0.831,\t validation loss: 0.4080569,\t validation accuracy: 0.81133336\n",
      "iter 1700\t training loss: 0.37435803,\t training accuracy: 0.834,\t validation loss: 0.39910385,\t validation accuracy: 0.81266665\n",
      "iter 1800\t training loss: 0.36678034,\t training accuracy: 0.8323333,\t validation loss: 0.3925525,\t validation accuracy: 0.81233335\n",
      "iter 1900\t training loss: 0.36079496,\t training accuracy: 0.83566666,\t validation loss: 0.38773572,\t validation accuracy: 0.81233335\n",
      "iter 2000\t training loss: 0.35580721,\t training accuracy: 0.83966666,\t validation loss: 0.38400057,\t validation accuracy: 0.8156667\n",
      "iter 2100\t training loss: 0.3513445,\t training accuracy: 0.8436667,\t validation loss: 0.38036272,\t validation accuracy: 0.81866664\n",
      "iter 2200\t training loss: 0.34720415,\t training accuracy: 0.84466666,\t validation loss: 0.37664217,\t validation accuracy: 0.8196667\n",
      "iter 2300\t training loss: 0.34326342,\t training accuracy: 0.8476667,\t validation loss: 0.37265036,\t validation accuracy: 0.826\n",
      "iter 2400\t training loss: 0.3392609,\t training accuracy: 0.84966666,\t validation loss: 0.3679434,\t validation accuracy: 0.82766664\n",
      "iter 2500\t training loss: 0.33505964,\t training accuracy: 0.84966666,\t validation loss: 0.3636952,\t validation accuracy: 0.832\n",
      "iter 2600\t training loss: 0.3306019,\t training accuracy: 0.852,\t validation loss: 0.3597662,\t validation accuracy: 0.8333333\n",
      "iter 2700\t training loss: 0.32598382,\t training accuracy: 0.85333335,\t validation loss: 0.35593197,\t validation accuracy: 0.83633333\n",
      "iter 2800\t training loss: 0.32101524,\t training accuracy: 0.8576667,\t validation loss: 0.3524064,\t validation accuracy: 0.838\n",
      "iter 2900\t training loss: 0.31531197,\t training accuracy: 0.85933334,\t validation loss: 0.34848136,\t validation accuracy: 0.842\n",
      "iter 3000\t training loss: 0.30852255,\t training accuracy: 0.861,\t validation loss: 0.34406328,\t validation accuracy: 0.846\n",
      "iter 3100\t training loss: 0.30055603,\t training accuracy: 0.865,\t validation loss: 0.33848092,\t validation accuracy: 0.84866667\n",
      "iter 3200\t training loss: 0.29329032,\t training accuracy: 0.8656667,\t validation loss: 0.33375865,\t validation accuracy: 0.8526667\n",
      "iter 3300\t training loss: 0.2928349,\t training accuracy: 0.867,\t validation loss: 0.3330968,\t validation accuracy: 0.8513333\n",
      "iter 3400\t training loss: 0.2850153,\t training accuracy: 0.86733335,\t validation loss: 0.32819343,\t validation accuracy: 0.85433334\n",
      "iter 3500\t training loss: 0.2816742,\t training accuracy: 0.8693333,\t validation loss: 0.32585967,\t validation accuracy: 0.8553333\n",
      "iter 3600\t training loss: 0.27845147,\t training accuracy: 0.87,\t validation loss: 0.32543993,\t validation accuracy: 0.85933334\n",
      "iter 3700\t training loss: 0.2794321,\t training accuracy: 0.87166667,\t validation loss: 0.32390258,\t validation accuracy: 0.859\n",
      "iter 3800\t training loss: 0.27215356,\t training accuracy: 0.87266666,\t validation loss: 0.3190985,\t validation accuracy: 0.86366665\n",
      "iter 3900\t training loss: 0.2701055,\t training accuracy: 0.8746667,\t validation loss: 0.31514907,\t validation accuracy: 0.8663333\n",
      "iter 4000\t training loss: 0.26721725,\t training accuracy: 0.87666667,\t validation loss: 0.31546187,\t validation accuracy: 0.8663333\n",
      "iter 4100\t training loss: 0.2655896,\t training accuracy: 0.8756667,\t validation loss: 0.3157615,\t validation accuracy: 0.867\n",
      "iter 4200\t training loss: 0.26295316,\t training accuracy: 0.8793333,\t validation loss: 0.31422254,\t validation accuracy: 0.866\n",
      "iter 4300\t training loss: 0.26088497,\t training accuracy: 0.88133335,\t validation loss: 0.31307566,\t validation accuracy: 0.8666667\n",
      "iter 4400\t training loss: 0.2588902,\t training accuracy: 0.88233334,\t validation loss: 0.31210607,\t validation accuracy: 0.86866665\n",
      "iter 4500\t training loss: 0.25694385,\t training accuracy: 0.882,\t validation loss: 0.3112569,\t validation accuracy: 0.87\n",
      "iter 4600\t training loss: 0.255024,\t training accuracy: 0.885,\t validation loss: 0.31042317,\t validation accuracy: 0.8696667\n",
      "iter 4700\t training loss: 0.2529742,\t training accuracy: 0.88566667,\t validation loss: 0.30942827,\t validation accuracy: 0.87\n",
      "iter 4800\t training loss: 0.25062183,\t training accuracy: 0.88766664,\t validation loss: 0.30894735,\t validation accuracy: 0.87\n",
      "iter 4900\t training loss: 0.24865916,\t training accuracy: 0.8883333,\t validation loss: 0.30784124,\t validation accuracy: 0.872\n"
     ]
    }
   ],
   "source": [
    "###train model####\n",
    "num_epoch = 5000\n",
    "# 训练集placeholder输入\n",
    "arcnode_train = tf.SparseTensorValue(indices=arcnode[0].indices, values=arcnode[0].values,\n",
    "                                    dense_shape=arcnode[0].dense_shape)\n",
    "fd_train = {comp_inp: inp[0], state: np.zeros((arcnode[0].dense_shape[0], state_dim)),\n",
    "          state_old: np.ones((arcnode[0].dense_shape[0], state_dim)),\n",
    "          ArcNode: arcnode_train, y: labels}\n",
    "#验证集placeholder输入\n",
    "arcnode_valid = tf.SparseTensorValue(indices=arcnode_val[0].indices, values=arcnode_val[0].values,\n",
    "                                dense_shape=arcnode_val[0].dense_shape)\n",
    "fd_valid = {comp_inp: inp_val[0], state: np.zeros((arcnode_val[0].dense_shape[0], state_dim)),\n",
    "      state_old: np.ones((arcnode_val[0].dense_shape[0], state_dim)),\n",
    "      ArcNode: arcnode_valid, y: labels_val}\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    sess.run(tf.local_variables_initializer())\n",
    "    for i in range(0, num_epoch):\n",
    "        _, loss_val, accuracy_val = sess.run(\n",
    "                    [train_op, loss, accuracy],\n",
    "                    feed_dict=fd_train)\n",
    "        if i % 100 == 0:\n",
    "\n",
    "            loss_valid_val, accuracy_valid_val = sess.run(\n",
    "                    [loss, accuracy],\n",
    "                    feed_dict=fd_valid)\n",
    "            print(\"iter %s\\t training loss: %s,\\t training accuracy: %s,\\t validation loss: %s,\\t validation accuracy: %s\" % \n",
    "                  (i, loss_val, accuracy_val, loss_valid_val, accuracy_valid_val))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
